`include "sdr_head.v"

module sdr_write(
    input                   clk,
    input                   write_en,
    input       [24 : 0]    int_addr,
    input       [31 : 0]    wdata,
    output reg              write_done,
    output      [19 : 0]    wr_bus,
    inout       [15 : 0]    sdr_dq
);

    // ------------------------------------------------------------------
    // Single 32-bit write sequencer for a 16-bit SDRAM data bus.
    //
    // While write_en is high the FSM walks
    //   ST_ACTIVATE -> ST_RCD_WAIT -> ST_DATA_HI -> ST_WR_WAIT -> ST_DONE
    // issuing `ACT (bank = int_addr[24:23], row = int_addr[22:10]),
    // NOPs until the `TRCD count expires, then `WR (column = int_addr[9:0],
    // A10 = 1) together with wdata[15:0], followed one cycle later by
    // wdata[31:16]. After the `TWR count expires, write_done is pulsed
    // for one clock and the FSM parks in ST_DONE with the DQ driver off.
    //
    // write_en low acts as a synchronous reset: the FSM returns to
    // ST_ACTIVATE, `NOP is placed on the command field and DQ is floated.
    //
    // NOTE(review): A10 = 1 on the WRITE presumably selects auto-precharge,
    // and the two-beat data phase presumably relies on a burst length of 2
    // programmed elsewhere -- confirm against the init / mode-register code.
    // ------------------------------------------------------------------

    localparam ST_ACTIVATE = 3'd0;  // issue ACT and open the row
    localparam ST_RCD_WAIT = 3'd1;  // NOP until `TRCD, then WR + low half-word
    localparam ST_DATA_HI  = 3'd2;  // NOP + high half-word
    localparam ST_WR_WAIT  = 3'd3;  // hold until `TWR, then pulse write_done
    localparam ST_DONE     = 3'd4;  // release DQ and wait for write_en to drop

    reg [2  : 0]    state;      // current FSM state
    reg [7  : 0]    cnt;        // wait counter shared by the `TRCD and `TWR waits
    reg [3  : 0]    cmd;        // command field of wr_bus
    reg [12 : 0]    sdr_a;      // address field of wr_bus
    reg [1  : 0]    sdr_ba;     // bank field of wr_bus
    reg             sdr_cke;    // clock-enable field of wr_bus
    reg [15 : 0]    dq;         // value driven onto sdr_dq while dq_en is high
    reg             dq_en;      // output enable for the sdr_dq driver

    // Packed command/address bus: {cke, cmd[3:0], a[12:0], ba[1:0]}.
    assign wr_bus = {sdr_cke, cmd, sdr_a, sdr_ba};

    // Drive DQ only while dq_en is set; otherwise leave it floating.
    assign sdr_dq = dq_en ? dq : {16{1'bz}};

    always @(posedge clk) begin
        if (!write_en) begin
            // Synchronous reset / idle.
            state      <= ST_ACTIVATE;
            cnt        <= 8'b0;
            cmd        <= `NOP;
            sdr_cke    <= 1'b1;
            sdr_ba     <= 2'b0;
            sdr_a      <= 13'b0;
            dq_en      <= 1'b0;
            dq         <= 16'b0;
            write_done <= 1'b0;
        end
        else begin
            case (state)
            ST_ACTIVATE: begin
                cmd    <= `ACT;
                sdr_ba <= int_addr[24:23];
                sdr_a  <= int_addr[22:10];
                // NOTE(review): the DQ output is enabled here, i.e. already
                // during ACT and the tRCD wait, not only at the WRITE.
                dq_en  <= 1'b1;
                state  <= ST_RCD_WAIT;
            end

            ST_RCD_WAIT: begin
                if (cnt < `TRCD) begin
                    cmd <= `NOP;
                    cnt <= cnt + 8'd1;
                end
                else begin
                    cmd         <= `WR;
                    sdr_ba      <= int_addr[24:23];
                    // A10 forced high, A[9:0] = column; A[12:11] keep the
                    // value loaded in the ACT cycle.
                    sdr_a[10:0] <= {1'b1, int_addr[9:0]};
                    dq          <= wdata[15:0];
                    cnt         <= 8'b0;
                    state       <= ST_DATA_HI;
                end
            end

            ST_DATA_HI: begin
                cmd   <= `NOP;
                dq    <= wdata[31:16];
                state <= ST_WR_WAIT;
            end

            ST_WR_WAIT: begin
                // cmd stays `NOP and DQ keeps driving wdata[31:16] here.
                if (cnt < `TWR) begin
                    cnt <= cnt + 8'd1;
                end
                else begin
                    cnt        <= 8'b0;
                    write_done <= 1'b1;
                    state      <= ST_DONE;
                end
            end

            ST_DONE: begin
                // Ends the one-cycle write_done pulse; state is left as-is,
                // so the FSM stays here until write_en deasserts.
                dq_en      <= 1'b0;
                write_done <= 1'b0;
            end

            default: state <= ST_DONE;
            endcase
        end
    end

endmodule 